import argparse
import cv2
import sys
import numpy as np
import os
import mxnet as mx
import datetime
from skimage import transform as trans
import insightface

def square_crop(im, S):
  if im.shape[0]>im.shape[1]:
    height = S
    width = int( float(im.shape[1]) / im.shape[0] * S )
    scale = float(S) / im.shape[0]
  else:
    width = S
    height = int( float(im.shape[0]) / im.shape[1] * S )
    scale = float(S) / im.shape[1]
  resized_im = cv2.resize(im, (width, height))
  det_im = np.zeros( (S, S, 3), dtype=np.uint8 )
  det_im[:resized_im.shape[0], :resized_im.shape[1], :] = resized_im
  return det_im, scale

def transform(data, center, output_size, scale, rotation):
    scale_ratio = scale
    rot = float(rotation)*np.pi/180.0
    #translation = (output_size/2-center[0]*scale_ratio, output_size/2-center[1]*scale_ratio)
    t1 = trans.SimilarityTransform(scale=scale_ratio)
    cx = center[0]*scale_ratio
    cy = center[1]*scale_ratio
    t2 = trans.SimilarityTransform(translation=(-1*cx, -1*cy))
    t3 = trans.SimilarityTransform(rotation=rot)
    t4 = trans.SimilarityTransform(translation=(output_size/2, output_size/2))
    t = t1+t2+t3+t4
    M = t.params[0:2]
    cropped = cv2.warpAffine(data,M,(output_size, output_size), borderValue = 0.0)
    return cropped, M

def trans_points2d(pts, M):
    new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
    for i in range(pts.shape[0]):
        pt = pts[i]
        new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
        new_pt = np.dot(M, new_pt)
        #print('new_pt', new_pt.shape, new_pt)
        new_pts[i] = new_pt[0:2]

    return new_pts

def trans_points3d(pts, M):
    scale = np.sqrt(M[0][0]*M[0][0] + M[0][1]*M[0][1])
    #print(scale)
    new_pts = np.zeros(shape=pts.shape, dtype=np.float32)
    for i in range(pts.shape[0]):
        pt = pts[i]
        new_pt = np.array([pt[0], pt[1], 1.], dtype=np.float32)
        new_pt = np.dot(M, new_pt)
        #print('new_pt', new_pt.shape, new_pt)
        new_pts[i][0:2] = new_pt[0:2]
        new_pts[i][2] = pts[i][2]*scale

    return new_pts

def trans_points(pts, M):
  if pts.shape[1]==2:
    return trans_points2d(pts, M)
  else:
    return trans_points3d(pts, M)


class Handler:
  def __init__(self, prefix, epoch, im_size=192, det_size=224, ctx_id=0):
    print('loading',prefix, epoch)
    if ctx_id>=0:
      ctx = mx.gpu(ctx_id)
    else:
      ctx = mx.cpu()
    image_size = (im_size, im_size)
    self.detector = insightface.model_zoo.get_model('retinaface_mnet025_v2') #can replace with your own face detector
    #self.detector = insightface.model_zoo.get_model('retinaface_r50_v1')
    self.detector.prepare(ctx_id=ctx_id)
    self.det_size = det_size
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
    all_layers = sym.get_internals()
    sym = all_layers['fc1_output']
    self.image_size = image_size
    model = mx.mod.Module(symbol=sym, context=ctx, label_names = None)
    model.bind(for_training=False, data_shapes=[('data', (1, 3, image_size[0], image_size[1]))])
    model.set_params(arg_params, aux_params)
    self.model = model
    self.image_size = image_size


  
  def get(self, img, get_all=False):
    out = []
    det_im, det_scale = square_crop(img, self.det_size)
    bboxes, _ = self.detector.detect(det_im)
    if bboxes.shape[0]==0:
        return out
    bboxes /= det_scale
    if not get_all:
      areas = []
      for i in range(bboxes.shape[0]):
        x = bboxes[i]
        area = (x[2]-x[0])*(x[3]-x[1])
        areas.append(area)
      m = np.argsort(areas)[-1]
      bboxes = bboxes[m:m+1]
    for i in range(bboxes.shape[0]):
      bbox = bboxes[i]
      input_blob = np.zeros( (1, 3)+self.image_size,dtype=np.float32)
      w, h = (bbox[2]-bbox[0]), (bbox[3]-bbox[1])
      center = (bbox[2]+bbox[0])/2, (bbox[3]+bbox[1])/2
      rotate = 0
      _scale = self.image_size[0]*2/3.0/max(w,h)
      rimg, M = transform(img, center, self.image_size[0], _scale, rotate)
      rimg = cv2.cvtColor(rimg, cv2.COLOR_BGR2RGB)
      rimg = np.transpose(rimg, (2,0,1)) #3*112*112, RGB
      input_blob[0] = rimg
      data = mx.nd.array(input_blob)
      db = mx.io.DataBatch(data=(data,))
      self.model.forward(db, is_train=False)
      pred = self.model.get_outputs()[-1].asnumpy()[0]
      if pred.shape[0]>=3000:
        pred = pred.reshape( (-1, 3) )
      else:
        pred = pred.reshape( (-1, 2) )
      pred[:,0:2] += 1
      pred[:,0:2] *= (self.image_size[0]//2)
      if pred.shape[1]==3:
        pred[:,2] *= (self.image_size[0]//2)

      IM = cv2.invertAffineTransform(M)
      pred = trans_points(pred, IM)
      out.append(pred)
    return out


if __name__ == '__main__':
  handler = Handler('./model/2d106_det', 0, ctx_id=7, det_size=640)
  im = cv2.imread('../../sample-images/t1.jpg')
  tim = im.copy()
  preds = handler.get(im, get_all=True)
  color = (200, 160, 75)
  for pred in preds:
    pred = np.round(pred).astype(np.int)
    for i in range(pred.shape[0]):
      p = tuple(pred[i])
      cv2.circle(tim, p, 1, color, 1,cv2.LINE_AA)
  cv2.imwrite('./test_out.jpg', tim)


